############ Used Libraries and Building Block Functions ############
######################### Libraries #########################
library(tidyr)
library(dplyr)
library(ggplot2)
library(MASS)
library(ggh4x)
library(Rcpp)
library(RcppArmadillo)
library(RcppParallel)
library(Rmosek)
library(Matrix)
library(RColorBrewer)
##############################################################
################### IPW Building Blocks #################

# 1. Uniform Kernel
# Uniform (box) kernel indicator.
#   W : matrix or data frame of covariates, one observation per row
#       (coerced with as.matrix()).
#   w : evaluation point (coerced with as.numeric()), one value per column.
#   h : bandwidth, i.e. the half-width of the box.
# Returns a logical vector with one entry per row of W that is TRUE when
# every coordinate of that row lies within h of the matching entry of w.
kernel_unif <- function(W, w, h = 0.5) {
  obs <- as.matrix(W)
  # Repeat the evaluation point on every row so it lines up with obs.
  centre <- matrix(as.numeric(w), nrow(obs), ncol(obs), byrow = TRUE)
  within_band <- abs(obs - centre) <= h
  # A row is inside the box only if all of its coordinates are.
  apply(within_band, 1, all)
}

# 2. Calculating IPW 
# Kernel-weighted IPW numerator at the evaluation point w.
#   W, w, h : passed on to kernel_unif() to flag the rows near w.
#   Y       : outcome vector.
#   A       : treatment vector (used as A and 1 - A, so presumably 0/1).
#   pi_hat  : estimated propensity scores, one per observation.
# Returns the sample mean of
#   (A * Y / pi_hat - (1 - A) * Y / (1 - pi_hat)) * kernel indicator.
ipw_numerator <- function(W, w, Y, A, pi_hat, h = 0.5) {
  in_window <- kernel_unif(W, w, h = h)
  treated_part <- A * Y / pi_hat
  control_part <- (1 - A) * Y / (1 - pi_hat)
  mean((treated_part - control_part) * in_window)
}

# Kernel-weighted IPW denominator at the evaluation point w: the share of
# rows of W that kernel_unif() flags as lying within bandwidth h of w.
ipw_denominator <- function(W, w, h = 0.5) {
  mean(kernel_unif(W, w, h = h))
}

# Local IPW estimate evaluated at every observed covariate row.
#   W      : matrix or data frame of covariates, one observation per row.
#   Y      : outcome vector.
#   A      : treatment vector.
#   pi_hat : estimated propensity scores.
#   h      : bandwidth forwarded to the kernel helpers.
# For each row i, the row itself is used as the evaluation point and the
# estimate is ipw_numerator() / ipw_denominator() at that point.
# Returns a numeric vector of length nrow(W); numeric(0) when W has no rows.
ipw_estimate <- function(W, Y, A, pi_hat, h = 0.5) {
  n <- nrow(W)
  tau_ipw <- numeric(n)
  # seq_len() instead of 1:n: with n == 0, 1:n would be c(1, 0) and the loop
  # would try to read rows that do not exist.
  for (i in seq_len(n)) {
    w <- as.numeric(W[i, ])
    numerator <- ipw_numerator(W, w, Y, A, pi_hat, h = h)
    denominator <- ipw_denominator(W, w, h = h)
    if (denominator == 0) {
      # No observation falls in the kernel window; the ratio below is then
      # not finite. Print the offending point for inspection.
      cat("Warning: Denominator is zero for i =", i, "\n")
      print(list(
        w = w,
        numerator = numerator,
        denominator = denominator
      ))
    }
    tau_ipw[i] <- numerator / denominator
  }
  tau_ipw
}
##############################################################



############# IPW Method ###############

# IPW-based subgroup assignment.
#   df : data frame; the code reads columns A, Y and X1..X6 from it.
#   W  : covariates used both as propensity-model regressors and as the
#        kernel coordinates in ipw_estimate().
# Steps: fit a logistic propensity model A ~ . on W, clip the fitted scores
# to [0.001, 0.999], compute the local IPW estimate tau.w at every row with
# bandwidth 0.5, and label rows "S1" when tau.w > 0 and "S2" otherwise.
# Returns a list with
#   finaldf : W (columns prefixed "W."), A, pi.hat, tau.w and group
#   covs1df : X1..X6, pi.hat, A for the rows labelled "S1"
#   covs2df : X1..X6, pi.hat, A for the rows labelled "S2"
IPW_method <- function(df, W) {
  treatment <- df$A

  # Logistic propensity model on all columns of W.
  ps_model <- glm(A ~ ., data = cbind(W, A = treatment), family = binomial)
  ps <- predict(ps_model, type = "response")
  # Keep the scores away from 0 and 1 so the inverse weights stay bounded.
  ps <- pmax(pmin(ps, 0.999), 0.001)

  tau_hat <- ipw_estimate(W, df$Y, treatment, ps, h = 0.5)

  finaldf <- data.frame(W = W, A = treatment, pi.hat = ps, tau.w = tau_hat)
  finaldf <- mutate(finaldf, group = ifelse(tau.w > 0, "S1", "S2"))

  # The same scored frame feeds both subgroup tables.
  scored <- mutate(data.frame(df, pi.hat = ps), group = finaldf$group)
  s1_covs <- dplyr::select(filter(scored, group == "S1"),
                           X1, X2, X3, X4, X5, X6, pi.hat, A)
  s2_covs <- dplyr::select(filter(scored, group == "S2"),
                           X1, X2, X3, X4, X5, X6, pi.hat, A)

  list(finaldf = finaldf, covs1df = s1_covs, covs2df = s2_covs)
}

#### OPTIMIZATION METHOD (Mean, RBF Kernel)

# Subgroup assignment and balancing weights via a mixed-integer conic
# program solved with MOSEK (Rmosek).
#
# Arguments
#   df          : data frame; columns F, A, Y and X1..X6 are read.
#   W           : covariate data frame; columns 1:6 form the balance basis,
#                 and a column F is expected (output column "W.F").
#   method      : "mean"  -> balance the raw columns W[, 1:6];
#                 "KR50" / "KR100" -> balance the leading 50 / 100 scaled
#                 eigenvectors of an RBF kernel Gram matrix of W[, 1:6].
#   lambda      : objective coefficient of the variable that bounds the sum
#                 of squared weights.
#   delta       : bound on each weighted treated-vs-control balance term.
#   delta_prime : bound on sum((2 * F - 1) * S) within each subgroup.
#   w.max       : coefficient linking a weight to its membership indicator
#                 (w_i <= w.max * S_i).
#   w.min       : NOTE(review): never used in the body.
#   M           : big-M constant of the absolute-value linearisation.
#
# Decision vector x (length 4n + 3)
#   x[1:n]          w1  weights attached to subgroup 1
#   x[(n+1):(2n)]   w2  weights attached to subgroup 2
#   x[(2n+1):(3n)]  S1  membership indicators of subgroup 1 (integer)
#   x[(3n+1):(4n)]  S2  membership indicators of subgroup 2 (integer)
#   x[4n+1]         t   objective coefficient -1
#   x[4n+2]         u   objective coefficient lambda
#   x[4n+3]         z   integer switch used with M
#
# Returns a list with
#   finaldf : W.F, w1, w2, S1, S2 and group ("S1" when S1 == 1, else "S2")
#   covs1df : X1..X6, weight, A for rows with S1 == 1
#   covs2df : X1..X6, weight, A for rows with S2 == 1
# Subgroups are relabelled so that S1 is the larger one.
Weight.via.Opt <- function(df, W, method, lambda = 1, delta = .05, delta_prime = 10,
                           w.max = 1, w.min = 0, M = 1e4) {
  # Fail fast on an unsupported method. Previously this case only printed a
  # message and then crashed later because B.X was never created.
  if (length(method) != 1 || !(method %in% c("mean", "KR50", "KR100"))) {
    stop("WRONG METHOD '", paste(method, collapse = ", "),
         "': only mean / KR50 / KR100 are allowed.", call. = FALSE)
  }

  # F.attr rather than F, so the logical alias F is not shadowed.
  F.attr <- df$F; A <- df$A; Y <- df$Y

  if (method == "mean") {
    B.X <- as.matrix(W[ ,c(1:6)])
    colB.X <- ncol(B.X)
  }
  else {
    # The C++ kernel source is located relative to the working directory.
    sourceCpp("kernelfunc/RBF_kernel_C_parallel.cpp")
    W.matrix <- as.matrix(W[,1:6])
    stv <- 1:nrow(W.matrix)
    K <- RBF_kernel_C_parallel(W.matrix, length(stv), stv, gamma=0.01)  # RBF kernel gram matrix
    # Number of retained eigen-directions is the numeric suffix of method.
    eigen.num <- as.integer(substr(method, start = 3, stop = nchar(method)))
    eigen.K <- eigen(K)
    K.V <- eigen.K$vectors[,1:eigen.num]
    K.Sigma <- diag(sqrt(eigen.K$values[1:eigen.num]))
    # Eigenvectors scaled by the square roots of their eigenvalues.
    K.reduced <- K.V %*% K.Sigma
    B.X <- K.reduced ; colB.X <- ncol(B.X)
  }
  n <- dim(df)[1]
  # Lower bound on the size of each subgroup.
  m <- min(30, round(n/10))
  ### MICP FORMULATION
  I.nn.mat <- as(.symDiagonal(n=n, x=1.), "dgCMatrix")

  # Objective: minimise -t + lambda * u.
  prob <- list(sense="min")
  prob$c <- c(rep(0,4*n), -1, lambda, 0)
  # Box bounds: w1, w2, S1, S2 and z in [0, 1]; t and u in [0, Inf).
  # NOTE(review): the upper bound on the weights is the literal 1 here, not
  # w.max -- confirm that this is intended.
  prob$bx <- rbind(blx=rep(0,4*n+3), bux=c(rep(1,4*n), rep(Inf,2), 1))

  # the non-conic part of the problem.
  # Row groups, in order (bounds are set in prob$bc below):
  #   n rows       S1_i + S2_i                               = 1
  #   2 rows       sum(S1), sum(S2)                          in [m, n]
  #   4 rows       sum(A*w1), sum((1-A)*w1),
  #                sum(A*w2), sum((1-A)*w2)                  = 1
  #   2n rows      w1_i - w.max*S1_i, w2_i - w.max*S2_i      <= 0
  #   2*colB.X     sum_i (2A_i-1) * B.X[i,k] * w_i           in [-delta, delta]
  #                (first for w1, then for w2)
  #   4 rows       with D = sum((2A-1)*Y*(w1 - w2)):
  #                  D - t + M*z >= 0,  -D - t - M*z >= -M,
  #                  D - t <= 0,        -D - t <= 0
  #                (these pin t to |D|, z selecting the sign)
  #   2 rows       sum((2F-1)*S1), sum((2F-1)*S2)            in [-delta_prime, delta_prime]
  prob$A <- Matrix(rbind(cbind(as(matrix(0,n,2*n), "dgCMatrix"), I.nn.mat, I.nn.mat, rep(0,n), rep(0,n), rep(0,n)),
                         c(rep(0,2*n), rep(1,n), rep(0,n+3)),
                         c(rep(0,3*n), rep(1,n), rep(0,3)),
                         c(A, rep(0,3*n+3)),
                         c(1-A, rep(0,3*n+3)),
                         c(rep(0,n), A, rep(0,2*n+3)),
                         c(rep(0,n), 1-A, rep(0,2*n+3)),
                         cbind(I.nn.mat, as(matrix(0,n,n), "dgCMatrix"), -w.max*I.nn.mat, as(matrix(0,n,n+3), "dgCMatrix")),
                         cbind(as(matrix(0,n,n), "dgCMatrix"), I.nn.mat, as(matrix(0,n,n), "dgCMatrix"), -w.max*I.nn.mat, rep(0,n), rep(0,n), rep(0,n)),
                         cbind(t(as.matrix((2*A-1) * B.X)), as(matrix(0,colB.X,3*n+3), "dgCMatrix")),
                         cbind(as(matrix(0,colB.X,n), "dgCMatrix"), t(as.matrix((2*A-1) * B.X)), as(matrix(0,colB.X,2*n+3), "dgCMatrix")),
                         c((2*A-1)*Y, -(2*A-1)*Y, rep(0, 2*n), -1, 0, M),
                         c(-(2*A-1)*Y, (2*A-1)*Y, rep(0, 2*n), -1, 0, -M),
                         c((2*A-1)*Y, -(2*A-1)*Y, rep(0, 2*n), -1, 0, 0),
                         c(-(2*A-1)*Y, (2*A-1)*Y, rep(0, 2*n), -1, 0, 0),
                         c(rep(0,n), rep(0,n), 2*F.attr-1, rep(0,n), 0, 0, 0),
                         c(rep(0,n), rep(0,n), rep(0,n), 2*F.attr-1, 0, 0, 0)
  ), sparse = TRUE)

  prob$bc <- rbind(blc=c(rep(1,n), rep(m,2), rep(1,4), rep(-Inf,2*n), rep(-delta,2*colB.X), 0, -M, -Inf, -Inf, -delta_prime, -delta_prime),
                   buc=c(rep(1,n), rep(n,2), rep(1,4), rep(0,2*n), rep(delta,2*colB.X), Inf, Inf, 0, 0, delta_prime, delta_prime)
  )

  # Conic part: the affine expression F x + g = (u, 1/2, w1 + w2) is placed
  # in one rotated quadratic cone of dimension n + 2. Under MOSEK's RQUAD
  # definition (2 * x1 * x2 >= sum of squares of the rest) this reads
  # u >= sum((w1_i + w2_i)^2).
  prob$F <- Matrix(rbind(
    c(rep(0,4*n+1), 1, 0),
    rep(0,4*n+3),
    cbind(I.nn.mat, I.nn.mat, as(matrix(0,n,2*n+3), "dgCMatrix"))
  ), sparse = TRUE)

  prob$g <- c(0, 1/2, rep(0,n)
  )
  prob$cones <- matrix(list("RQUAD", n+2, NULL), nrow=3, ncol=1)
  rownames(prob$cones) <- c("type","dim","conepar")

  # Specify the integer constraints
  prob$intsub <- c(seq(2*n+1, 4*n, by=1), 4*n+3)
  # solution
  prob$dparam <- list(MIO_MAX_TIME = 2000) # time limit: 2000 seconds
  r <- mosek(prob)

  # Without an integer solution everything below would silently become NULL
  # and fail later with an unrelated error, so stop here instead.
  xx <- r$sol$int$xx
  if (is.null(xx)) {
    stop("MOSEK returned no integer solution (r$sol$int$xx is NULL).",
         call. = FALSE)
  }

  S.ind <- xx[seq(2*n+1, 4*n, by=1)]
  w.opt <- xx[1:(2*n)]
  w.1 <- w.opt[1:n]
  w.2 <- w.opt[(n+1):(2*n)]
  S.1 <- S.ind[1:n]
  S.2 <- S.ind[(n+1):(2*n)]
  # Relabel so that subgroup 1 is the larger one.
  if (sum(S.2) > sum(S.1)){
    temp.w <- w.1
    w.1 <- w.2
    w.2 <- temp.w

    # S.1 <-> S.2 swap
    temp.S <- S.1
    S.1 <- S.2
    S.2 <- temp.S
  }
  # The solver returns the indicators as doubles; round() snaps them to 0/1.
  finaldf <- data.frame(W=W, w1=w.1, w2= w.2, S1 = round(S.1), S2= round(S.2)) %>%
    mutate(group=ifelse(S1 == 1,"S1","S2")) %>%
    dplyr::select(W.F, w1, w2, S1, S2, group)

  # One frame carrying each row's own-subgroup weight feeds both tables.
  covdf <- data.frame(df, w.1, w.2, S1= round(S.1), S2= round(S.2))
  covdf$weight <- ifelse(covdf$S1 == 1, covdf$w.1, covdf$w.2)
  fig2df <- covdf %>%
    filter(S1 == 1) %>%
    dplyr::select(X1,X2,X3,X4,X5,X6,weight,A)
  fig3df <- covdf %>%
    filter(S2 == 1) %>%
    dplyr::select(X1,X2,X3,X4,X5,X6,weight,A)
  return(list(finaldf = finaldf, covs1df = fig2df, covs2df = fig3df))
}










######### Visualization ##########

## PLOT 1 (FAIRNESS DISTRIBUTION)
# Proportion-scaled stacked bar chart of W.F within each subgroup.
#   finaldf : data frame with at least the columns `group` and `W.F`.
# Rows whose group is neither "S1" nor "S2" are dropped, and W.F is turned
# into a factor so it is drawn as a discrete fill. Returns the ggplot object.
fairness_vis <- function(finaldf) {
  in_subgroups <- filter(finaldf, group %in% c("S1", "S2"))
  bar_input <- mutate(dplyr::select(in_subgroups, W.F, group),
                      W.F = as.factor(W.F))

  # Axis labels rendered with subscripts.
  subgroup_labels <- c("S1" = expression(S[1]),
                       "S2" = expression(S[2]))

  # Legend placement and text sizes, applied on top of theme_minimal().
  text_sizes <- theme(legend.position = "right",
                      legend.key.size = unit(1.5, "lines"),
                      legend.title = element_text(size=14),
                      legend.text = element_text(size=12),
                      axis.text.x = element_text(size = 16))

  ggplot(bar_input, aes(x = group, fill = W.F)) +
    geom_bar(position = "fill") +
    scale_y_continuous(labels = scales::percent) +
    scale_x_discrete(labels = subgroup_labels) +
    labs(x = "Subgroup", y = "Proportion (%)", fill = "F") +
    theme_minimal() +
    text_sizes
}


## PLOT 2,3 (Covariate Balance for each subgroup)
# Density plots of the weighted covariates X1..X6, treated vs control,
# for one subgroup.
#   result      : list with elements covs1df and covs2df. Each is read by
#                 position: columns 1:6 are X1..X6 and column 7 is the
#                 weighting column (pi.hat or weight); a column A is also used.
#   subgroup.no : 1 (uses covs1df, palette "Set1") or
#                 2 (uses covs2df, palette "Set2", reversed direction).
#   lb, ub      : lower / upper quantile levels; per covariate, only values
#                 strictly between these two quantiles are plotted.
# Weighting: with a pi.hat column, treated rows are divided by pi.hat and
# control rows by 1 - pi.hat; with a weight column, both are multiplied by
# weight. Stops with an error for any other subgroup.no or when neither
# column is present. Returns the ggplot object.
covbal_vis <- function(result, subgroup.no, lb, ub){
  if (subgroup.no == 1){
    figdf <- result$covs1df
    palette.key <- "Set1"
  }
  else if (subgroup.no == 2){
    figdf <- result$covs2df
    palette.key <- "Set2"
  }
  else {
    # Previously this only printed a message and then failed further down
    # because figdf was never assigned.
    stop("NO SUBGROUP EXISTS! subgroup.no must be 1 or 2.", call. = FALSE)
  }
  A1 <- figdf[figdf$A == 1, 1:7]
  A0 <- figdf[figdf$A == 0, 1:7]
  if ("pi.hat" %in% colnames(figdf)){
    A1 <- A1 %>%
      mutate(across(X1:X6, ~ .x / pi.hat, .names = "weighted.{.col}"))
    # Columns 8:13 are the six weighted.* columns appended by mutate().
    weighted.A1 <- A1[, c(8:13)]
    A0 <- A0 %>%
      mutate(across(X1:X6, ~ .x / (1 - pi.hat), .names = "weighted.{.col}"))
    weighted.A0 <- A0[, c(8:13)]
  }
  else if ("weight" %in% colnames(figdf)){
    A1 <- A1 %>%
      mutate(across(X1:X6, ~ .x * weight, .names = "weighted.{.col}"))
    weighted.A1 <- A1[,c(8:13)]
    A0 <- A0 %>%
      mutate(across(X1:X6, ~ .x * weight, .names = "weighted.{.col}"))
    weighted.A0 <- A0[,c(8:13)]
  }
  else {
    # Without either column the weighted frames would be undefined below.
    stop("The subgroup data frame needs a 'pi.hat' or a 'weight' column.",
         call. = FALSE)
  }

  plot_data <- bind_rows(
    mutate(weighted.A1, group = "A=1"),
    mutate(weighted.A0, group = "A=0")
  )

  plot_data_long <- plot_data %>%
    pivot_longer(cols = starts_with("weighted"), names_to = "variable", values_to = "value") %>%
    mutate(variable = gsub("weighted\\.", "", variable))

  # Trim the tails per covariate so extreme weighted values do not dominate
  # the density plots.
  plot_data_filtered <- plot_data_long %>%
    group_by(variable) %>%
    filter(value > quantile(value, lb) & value < quantile(value, ub)) %>%
    ungroup() %>%
    mutate(variable = factor(variable, levels = c("X1", "X2", "X3", "X4", "X5", "X6")))

  ggplot(plot_data_filtered, aes(x = value, fill = group, color = group)) +
    geom_density(alpha = 0.3) +
    facet_wrap(~ variable, scales = "free", ncol = 3) +
    #### s1 : direction = 1 , s2 : direction = -1
    scale_fill_brewer(palette = palette.key, direction = -2 * subgroup.no + 3 ) +  # Set2 palette is applied for s2
    scale_color_brewer(palette = palette.key, direction = -2 * subgroup.no + 3) +
    labs(
      fill = "Group",
      color = "Group"
    ) +
    theme_minimal() +
    theme(
      # plot.title
      plot.title = element_blank(),
      # remove the axis labels (numbers)
      axis.text.x = element_blank(),
      axis.text.y = element_blank(),
      # to also remove the axis ticks themselves:
      #axis.ticks = element_blank(),

      legend.position = "bottom",
      legend.key.size = unit(1.5, "lines"),
      legend.title = element_text(size=14),
      legend.text = element_text(size=12),
      strip.text = element_text(size = 12)
    )

}

###### 
